
#CNN 和 线性


import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

# Load MNIST (downloaded to ./data if missing) with images converted to tensors.
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=ToTensor())
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Accuracy does not depend on sample order, so the test set is not shuffled;
# this keeps the evaluation pass deterministic.
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Network definition
class CnnNetwork(torch.nn.Module):
    """Three convolutional layers followed by a three-layer fully connected head.

    Takes single-channel images. For 28x28 inputs the feature map after the
    last pooling stage is 32 x 5 x 5 = 800 values, which is exactly what
    ``linear1`` consumes. Returns unnormalised logits of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 128 -> 64 -> 32 channels.
        self.conv1 = torch.nn.Conv2d(1, 128, 3, 1)
        self.conv2 = torch.nn.Conv2d(128, 64, 2, 1)
        self.conv3 = torch.nn.Conv2d(64, 32, 3, 1)
        # Classifier head: 800 -> 256 -> 128 -> 10.
        self.linear1 = torch.nn.Linear(1 * 32 * 5 * 5, 256)
        self.linear2 = torch.nn.Linear(256, 128)
        self.dropout = torch.nn.Dropout(0.5)
        self.linear3 = torch.nn.Linear(128, 10)

    # Forward pass
    def forward(self, x):
        """Map a batch of images to class logits.

        Shape trace for a (N, 1, 28, 28) input:
        conv1 -> (N, 128, 26, 26); conv2 -> (N, 64, 25, 25); pool -> (N, 64, 12, 12);
        conv3 -> (N, 32, 10, 10); pool -> (N, 32, 5, 5); flatten -> (N, 800).
        """
        x = torch.nn.functional.relu(self.conv1(x))
        x = torch.nn.functional.relu(self.conv2(x))
        x = torch.nn.functional.max_pool2d(x, 2)
        x = torch.nn.functional.relu(self.conv3(x))
        x = torch.nn.functional.max_pool2d(x, 2)
        # Flatten per sample, keeping the batch dimension. A view(-1, 800) would
        # silently fold extra spatial positions into the batch axis for inputs
        # that are not 28x28; with flatten, linear1 raises on a size mismatch instead.
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.linear1(x))
        x = self.dropout(x)
        x = torch.nn.functional.relu(self.linear2(x))
        return self.linear3(x)

# Build the network and move its parameters to the selected device.
model = CnnNetwork().to(device)
# Cross-entropy loss over the network's raw (unnormalised) logits.
loss_fn = torch.nn.CrossEntropyLoss()
# Adam over every trainable parameter of the model, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader`` and report the average loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used as the total sample count.
        model: Module to train. It is switched to train mode, and each batch is
            moved to the device of the model's first parameter.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
            Assumed to be a per-batch mean (e.g. ``CrossEntropyLoss`` with the
            default reduction); a sum-reduced loss would inflate the average.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The sample-weighted mean loss over the epoch as a Python float (also
        printed). ``nan`` if the dataset is empty.
    """
    model.train()
    device = next(model.parameters()).device
    size = len(dataloader.dataset)
    total_loss = 0.0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain float; accumulating the tensor itself would
        # keep every batch's autograd graph alive until the epoch ends.
        batch_loss = loss.item()
        # Weight by batch size so a short final batch does not skew the mean.
        total_loss += batch_loss * len(X)

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    avg_loss = total_loss / size if size else float("nan")
    print(f"Loss :{avg_loss}")
    return avg_loss

# Train for a fixed number of epochs; train() prints running and per-epoch loss.
NUM_EPOCHS = 10
model.train()
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    train(train_dataloader, model, loss_fn, optimizer)

# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure ../Model exists first.
model_path = "../Model/model.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Reload the saved weights into a fresh network and evaluate it on the test set.
model_test = CnnNetwork().to(device)
model_test.load_state_dict(
    torch.load("../Model/model.pth", map_location=device, weights_only=True),
    strict=True,
)
# eval() disables dropout so test predictions are not randomly perturbed.
model_test.eval()

size = len(test_dataloader.dataset)
correct = 0
with torch.no_grad():
    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)
        # Score the reloaded, eval-mode network (not the in-memory training
        # model, which is still in train mode with dropout active).
        pred = model_test(X)
        correct += (pred.argmax(1) == y).sum().item()

print(f"Test Error: \n Accuracy: {(100*correct/size):>0.1f}%\n")

























































